{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 决策树解决回归问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "\n",
    "boston = datasets.load_boston()\n",
    "X = boston.data\n",
    "y = boston.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decision Tree Regressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=None,\n                      max_features=None, max_leaf_nodes=None,\n                      min_impurity_decrease=0.0, min_impurity_split=None,\n                      min_samples_leaf=1, min_samples_split=2,\n                      min_weight_fraction_leaf=0.0, presort='deprecated',\n                      random_state=None, splitter='best')"
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "\n",
    "dt_reg = DecisionTreeRegressor()\n",
    "dt_reg.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "0.5809086915846111"
     },
     "metadata": {},
     "execution_count": 5
    }
   ],
   "source": [
    "dt_reg.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1.0"
     },
     "metadata": {},
     "execution_count": 7
    }
   ],
   "source": [
    "dt_reg.score(X_train, y_train)\n",
    "# 训练数据集得到1.0的分数，出现过拟合现象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "0.586506915504435\n0.9880424602693549\n"
    }
   ],
   "source": [
    "dt_reg2 = DecisionTreeRegressor(max_depth=10)\n",
    "dt_reg2.fit(X_train, y_train)\n",
    "print(dt_reg2.score(X_test, y_test))\n",
    "print(dt_reg2.score(X_train, y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Fitting 5 folds for each of 20736 candidates, totalling 103680 fits\n[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n[Parallel(n_jobs=-1)]: Done  25 tasks      | elapsed:    1.0s\n[Parallel(n_jobs=-1)]: Done 2696 tasks      | elapsed:    2.6s\n[Parallel(n_jobs=-1)]: Done 15688 tasks      | elapsed:    8.2s\n[Parallel(n_jobs=-1)]: Done 33800 tasks      | elapsed:   16.5s\n[Parallel(n_jobs=-1)]: Done 57160 tasks      | elapsed:   28.0s\n[Parallel(n_jobs=-1)]: Done 85640 tasks      | elapsed:   42.3s\n[Parallel(n_jobs=-1)]: Done 103680 out of 103680 | elapsed:   51.2s finished\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "GridSearchCV(cv=None, error_score=nan,\n             estimator=DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse',\n                                             max_depth=None, max_features=None,\n                                             max_leaf_nodes=None,\n                                             min_impurity_decrease=0.0,\n                                             min_impurity_split=None,\n                                             min_samples_leaf=1,\n                                             min_samples_split=2,\n                                             min_weight_fraction_leaf=0.0,\n                                             presort='deprecated',\n                                             random_state=None,\n                                             splitter='best'),\n             iid='deprecated', n_jobs=-1,\n             param_grid=[{'max_depth': [2, 3, 4, 5, 6, 7, 8, 9, 10],\n                          'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9, 10],\n                          'min_samples_leaf': [5, 6, 7, 8, 9, 10, 11, 12, 13,\n                                               14, 15, 16, 17, 18, 19, 20],\n                          'min_samples_split': [5, 6, 7, 8, 9, 10, 11, 12, 13,\n                                                14, 15, 16, 17, 18, 19, 20]}],\n             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,\n             scoring=None, verbose=2)"
     },
     "metadata": {},
     "execution_count": 15
    }
   ],
   "source": [
    "param_grid = [\n",
    "    {\n",
    "        'max_depth': [i for i in range(2, 11)], \n",
    "        'max_leaf_nodes': [i for i in range(2, 11)],\n",
    "        'min_samples_leaf': [i for i in range(5, 21)],\n",
    "        'min_samples_split': [i for i in range(5, 21)],\n",
    "    }\n",
    "]\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "dt_reg3 = DecisionTreeRegressor()\n",
    "grid_search = GridSearchCV(dt_reg3, param_grid, n_jobs=-1, verbose=2)\n",
    "grid_search.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "0.761369265292256\n{'max_depth': 4, 'max_leaf_nodes': 10, 'min_samples_leaf': 9, 'min_samples_split': 6}\n"
    }
   ],
   "source": [
    "print(grid_search.best_score_)\n",
    "print(grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "0.6462430360432516"
     },
     "metadata": {},
     "execution_count": 18
    }
   ],
   "source": [
    "dt_reg_best = grid_search.best_estimator_\n",
    "dt_reg_best.score(X_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.8 64-bit ('anaconda': conda)",
   "language": "python",
   "name": "python36864bitanacondaconda085db2e8b14649f4b32196068f7124e9"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}